{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "680b19b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "import os\n",
    "import pandas as pd\n",
    "import time\n",
    "\n",
    "from processing.load_datasets import load_rts\n",
    "from processing.configs import COLORS_10\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "from umap import UMAP\n",
    "from trimap import TRIMAP\n",
    "from pacmap import PaCMAP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b993293",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the RTS frame through the project loader. Later cells read the\n",
    "# `contentType` and `imagenet_features` columns from it.\n",
    "# NOTE(review): merge_metadata=True presumably joins the metadata columns onto\n",
    "# the feature rows - confirm in processing.load_datasets.\n",
    "rts_df = load_rts(merge_metadata=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93c4f852",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Canonical sport name -> RTS `contentType` labels grouped under it.\n",
    "# Insertion order is significant: the plotting cell assigns colors by key position.\n",
    "sports_dict = {\n",
    "    \"Football\": [\"Football en direct\", \"Football enregistrement\"],\n",
    "    \"Tennis\": [\"Tennis en direct\"],\n",
    "    \"Alpine skiing and snowboarding\": [\n",
    "        \"Ski alpin, snowboard en direct\",\n",
    "        \"Ski alpin, snowboard enregistrement\",\n",
    "    ],\n",
    "    \"Cross country skiing\": [\"Ski de fond en direct\"],\n",
    "    \"Ice hockey\": [\"Hockey sur glace en direct\"],\n",
    "    \"Swimming\": [\"Natation en direct\"],\n",
    "    \"Volleyball\": [\"Volley-ball en direct\"],\n",
    "    \"Athletics\": [\"Athlétisme en direct\", \"Athlétisme enregistrement\"],\n",
    "    \"Cycling\": [\"Cyclisme en direct\"],\n",
    "    \"Motorcycling\": [\"Motocyclisme en direct\"],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0bcbc37",
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_sport_class(label, sports_dict, default=None):\n",
    "    \"\"\"\n",
    "    Resolve an RTS class label to the canonical sport it is listed under.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    label : str or None\n",
    "        RTS class label to look up.\n",
    "    sports_dict : dict\n",
    "        Mapping of canonical sport name -> collection of RTS labels.\n",
    "    default : any, optional\n",
    "        Fallback returned when `label` is None or is listed under no sport.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    str or default\n",
    "        The first sport, in dict iteration order, whose labels contain `label`.\n",
    "    \"\"\"\n",
    "    if label is None:\n",
    "        return default\n",
    "\n",
    "    # Lazily scan the mapping and stop at the first sport that lists the label.\n",
    "    matches = (name for name, members in sports_dict.items() if label in members)\n",
    "    return next(matches, default)\n",
    "\n",
    "# Tag every row with its canonical sport; rows whose contentType is not listed\n",
    "# in sports_dict receive the function's default (None).\n",
    "rts_df[\"sport\"] = rts_df[\"contentType\"].apply(map_sport_class, sports_dict=sports_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5be69e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the rows that were mapped to one of the selected sports.\n",
    "rts_df = rts_df[rts_df[\"sport\"].notna()]\n",
    "print(f\"Total samples after filtering: {len(rts_df)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e7ce79a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Base count for the TriMap triplet settings below (n_inliers=2c, n_outliers=c,\n",
    "# n_random=c).\n",
    "c = 100\n",
    "\n",
    "# Fixed seed so that embeddings recomputed after a kernel restart are reproducible.\n",
    "RANDOM_STATE = 42\n",
    "\n",
    "# One estimator per DR method; the dict order also defines the panel order of the\n",
    "# comparison figure.\n",
    "# NOTE(review): TRIMAP is constructed without a seed, exactly as before - confirm\n",
    "# whether the installed trimap version exposes one.\n",
    "dr_algos = {\n",
    "    \"tsne\": TSNE(n_components=2, perplexity=50, random_state=RANDOM_STATE),\n",
    "    \"umap\": UMAP(n_components=2, min_dist=0.5, n_neighbors=30, random_state=RANDOM_STATE),\n",
    "    \"trimap\": TRIMAP(n_dims=2, n_inliers=2*c, n_outliers=c, n_random=c),\n",
    "    \"pacmap\": PaCMAP(n_components=2, n_neighbors=30, MN_ratio=5.0, FP_ratio=5.0, random_state=RANDOM_STATE),\n",
    "}\n",
    "\n",
    "algo_names = list(dr_algos.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13f26aa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "N_SAMPLE = 10000\n",
    "# Cap the request at the number of available rows: DataFrame.sample raises a\n",
    "# ValueError when asked for more rows than exist (replace=False).\n",
    "rts_sample = rts_df.sample(min(N_SAMPLE, len(rts_df)), random_state=42)\n",
    "X = np.stack(rts_sample[\"imagenet_features\"].values)\n",
    "\n",
    "# Save labels once; the plotting cell reads them back from this file.\n",
    "labels_path = \"data/rts_sports_labels.csv\"\n",
    "if not os.path.exists(labels_path):\n",
    "    os.makedirs(os.path.dirname(labels_path), exist_ok=True)\n",
    "    rts_sample[\"sport\"].to_csv(labels_path, index=False)\n",
    "\n",
    "# Create the cache directory before fitting so that a long fit is not lost to a\n",
    "# FileNotFoundError raised by np.save on a fresh checkout.\n",
    "embeddings_dir = \"embeddings/rts_sports\"\n",
    "os.makedirs(embeddings_dir, exist_ok=True)\n",
    "\n",
    "# NOTE: cached files are keyed by algorithm name only. Delete them (and the labels\n",
    "# CSV) after changing the data, N_SAMPLE or any hyperparameter.\n",
    "dr_results = {}\n",
    "for algo_name, algo in dr_algos.items():\n",
    "    output_path = f\"{embeddings_dir}/{algo_name}.npy\"\n",
    "    if os.path.exists(output_path):\n",
    "        print(f\"Loading existing embedding for {algo_name}...\")\n",
    "        X_dr = np.load(output_path)\n",
    "        # The runtime of a cached embedding is unknown, hence time=None.\n",
    "        dr_results[algo_name] = {\n",
    "            \"embedding\": X_dr,\n",
    "            \"time\": None\n",
    "        }\n",
    "        continue\n",
    "\n",
    "    print(f\"Running {algo_name}...\")\n",
    "    # perf_counter is monotonic, so the measured duration cannot be skewed by\n",
    "    # system clock adjustments the way time.time can.\n",
    "    start_time = time.perf_counter()\n",
    "    X_dr = algo.fit_transform(X)\n",
    "    end_time = time.perf_counter()\n",
    "    dr_results[algo_name] = {\n",
    "        \"embedding\": X_dr,\n",
    "        \"time\": end_time - start_time\n",
    "    }\n",
    "\n",
    "    # Save embeddings\n",
    "    np.save(output_path, X_dr)\n",
    "\n",
    "    print(f\"{algo_name} completed in {end_time - start_time:.2f} seconds\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c7219f4",
   "metadata": {},
   "source": [
    "# Visualisation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8c67b46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot every embedding in a 2x2 grid, one point per sample, colored by sport.\n",
    "# Labels are read back from the CSV written by the sampling cell.\n",
    "sport_labels = pd.read_csv(\"data/rts_sports_labels.csv\")\n",
    "sport_categories = list(sports_dict.keys())\n",
    "# Colors are assigned by the position of each sport in sports_dict.\n",
    "color_map = {sport: COLORS_10[i] for i, sport in enumerate(sport_categories)}\n",
    "color_list = [color_map[sport] for sport in sport_labels[\"sport\"]]\n",
    "\n",
    "# Total extra extent added to each axis: 0.1 means a 5% margin on every side.\n",
    "padding_ratio = 0.1\n",
    "\n",
    "# Panel titles, listed in the same order as algo_names.\n",
    "titles = [\n",
    "    \"t-SNE: perplexity=50\",\n",
    "    \"UMAP: min_dist=0.5, n_neighbors=30\",\n",
    "    \"TriMap: n_inliers=2c, n_outliers=c, n_random=c (c=100)\",\n",
    "    \"PaCMAP: n_neighbors=30, MN_ratio=5.0, FP_ratio=5.0\"\n",
    "]\n",
    "\n",
    "fig, axes = plt.subplots(2, 2, figsize=(15, 15))\n",
    "\n",
    "for i, algo_name in enumerate(algo_names):\n",
    "    ax = axes[i // 2, i % 2]\n",
    "    X_dr = dr_results[algo_name][\"embedding\"]\n",
    "\n",
    "    # Labels and embeddings come from independent cache files, so fail with an\n",
    "    # explicit message when they no longer describe the same sample.\n",
    "    if len(X_dr) != len(color_list):\n",
    "        raise ValueError(\n",
    "            f\"{algo_name}: embedding has {len(X_dr)} rows but the labels file has \"\n",
    "            f\"{len(color_list)}; the cached files are out of sync.\"\n",
    "        )\n",
    "\n",
    "    ax.scatter(X_dr[:, 0], X_dr[:, 1], c=color_list, alpha=0.6, s=5)\n",
    "\n",
    "    # --- Square + padding ---\n",
    "    x_min, x_max = X_dr[:, 0].min(), X_dr[:, 0].max()\n",
    "    y_min, y_max = X_dr[:, 1].min(), X_dr[:, 1].max()\n",
    "\n",
    "    # Use the larger of the two extents for both axes so the view is square.\n",
    "    max_range = max(x_max - x_min, y_max - y_min)\n",
    "\n",
    "    # Add padding\n",
    "    padded_range = max_range * (1 + padding_ratio)\n",
    "\n",
    "    x_center = (x_max + x_min) / 2\n",
    "    y_center = (y_max + y_min) / 2\n",
    "\n",
    "    ax.set_xlim(x_center - padded_range / 2, x_center + padded_range / 2)\n",
    "    ax.set_ylim(y_center - padded_range / 2, y_center + padded_range / 2)\n",
    "\n",
    "    ax.set_aspect(\"equal\")\n",
    "\n",
    "    ax.set_title(titles[i], fontsize=16)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([])\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "# Legend\n",
    "handles = [mpatches.Patch(color=color_map[sport], label=sport) for sport in sport_categories]\n",
    "fig.legend(handles=handles, loc=\"lower center\", bbox_to_anchor=(0.5, 0), ncol=5, frameon=False)\n",
    "\n",
    "# Reserve space at bottom\n",
    "plt.subplots_adjust(bottom=0.05)\n",
    "\n",
    "# savefig does not create missing directories, so make sure the target exists.\n",
    "os.makedirs(\"images\", exist_ok=True)\n",
    "plt.savefig(\"images/rts_sports_comparison.png\", dpi=300, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e493c3f0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dr_eval",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
